{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import xml.etree.ElementTree as ET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared settings.\n",
    "# Label names in class-id order: a label's position in this tuple is the\n",
    "# integer id written for it in the generated annotation files.\n",
    "class_index = (\n",
    "    'mask',\n",
    "    'nomask',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training split: directories holding the images and their annotation files.\n",
    "images_path = 'datasets/train/JPEGImages'\n",
    "annotations_path = 'datasets/train/Annotations'\n",
    "\n",
    "# List both directories, each sorted by file name.\n",
    "images = os.listdir(images_path)\n",
    "images.sort()\n",
    "annotations = os.listdir(annotations_path)\n",
    "annotations.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training annotation list.\n",
    "# One line per annotated image:\n",
    "#   <image path> <xmin>,<ymin>,<xmax>,<ymax>,<class id> ...\n",
    "# Each annotation is matched to its image by base file name instead of by\n",
    "# list position, so a stray or missing file cannot shift the pairing and\n",
    "# attach boxes to the wrong image.\n",
    "image_by_stem = {os.path.splitext(fname)[0]: fname for fname in images}\n",
    "\n",
    "lines = []\n",
    "for annotation_file in annotations:\n",
    "    stem = os.path.splitext(annotation_file)[0]\n",
    "    if stem not in image_by_stem:\n",
    "        raise FileNotFoundError(\n",
    "            'no image in {} matches annotation {}'.format(images_path, annotation_file))\n",
    "\n",
    "    fields = [images_path + '/' + image_by_stem[stem]]\n",
    "    tree = ET.parse(annotations_path + '/' + annotation_file)\n",
    "    for obj in tree.findall('object'):\n",
    "        bndbox = obj.find('bndbox')\n",
    "        box = [bndbox.find(tag).text for tag in ('xmin', 'ymin', 'xmax', 'ymax')]\n",
    "        # The class id is the label's position in class_index; an unknown\n",
    "        # label raises ValueError here.\n",
    "        box.append(str(class_index.index(obj.find('name').text)))\n",
    "        fields.append(','.join(box))\n",
    "    # Every field, including the last, is followed by a space, exactly as\n",
    "    # in the previously generated files.\n",
    "    lines.append(' '.join(fields) + ' \\n')\n",
    "\n",
    "# Write only after every annotation parsed, so a failure above cannot leave a\n",
    "# truncated list on disk.\n",
    "with open('model_data/mask_train.txt', 'w') as f:\n",
    "    f.writelines(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation split: directories holding the images and their annotation files.\n",
    "# NOTE: this rebinds the same names the training cells used.\n",
    "images_path = 'datasets/val/JPEGImages'\n",
    "annotations_path = 'datasets/val/Annotations'\n",
    "\n",
    "# List both directories, each sorted by file name.\n",
    "images = os.listdir(images_path)\n",
    "images.sort()\n",
    "annotations = os.listdir(annotations_path)\n",
    "annotations.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the validation annotation list.\n",
    "# One line per annotated image:\n",
    "#   <image path> <xmin>,<ymin>,<xmax>,<ymax>,<class id> ...\n",
    "# Each annotation is matched to its image by base file name instead of by\n",
    "# list position, so a stray or missing file cannot shift the pairing and\n",
    "# attach boxes to the wrong image.\n",
    "image_by_stem = {os.path.splitext(fname)[0]: fname for fname in images}\n",
    "\n",
    "lines = []\n",
    "for annotation_file in annotations:\n",
    "    stem = os.path.splitext(annotation_file)[0]\n",
    "    if stem not in image_by_stem:\n",
    "        raise FileNotFoundError(\n",
    "            'no image in {} matches annotation {}'.format(images_path, annotation_file))\n",
    "\n",
    "    fields = [images_path + '/' + image_by_stem[stem]]\n",
    "    tree = ET.parse(annotations_path + '/' + annotation_file)\n",
    "    for obj in tree.findall('object'):\n",
    "        bndbox = obj.find('bndbox')\n",
    "        box = [bndbox.find(tag).text for tag in ('xmin', 'ymin', 'xmax', 'ymax')]\n",
    "        # The class id is the label's position in class_index; an unknown\n",
    "        # label raises ValueError here.\n",
    "        box.append(str(class_index.index(obj.find('name').text)))\n",
    "        fields.append(','.join(box))\n",
    "    # Every field, including the last, is followed by a space, exactly as\n",
    "    # in the previously generated files.\n",
    "    lines.append(' '.join(fields) + ' \\n')\n",
    "\n",
    "# Write only after every annotation parsed, so a failure above cannot leave a\n",
    "# truncated list on disk.\n",
    "with open('model_data/mask_val.txt', 'w') as f:\n",
    "    f.writelines(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
